{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Exporting a model from Chainer to ONNX\n",
    "\n",
    "In this tutorial, we describe how to use ONNX-Chainer to convert a model defined in Chainer into the ONNX format.\n",
    "\n",
    "ONNX export is provided as a separate package [onnx-chainer](https://github.com/chainer/onnx-chainer). You can install it via pip like this:\n",
    "\n",
    "```\n",
    "pip install onnx-chainer\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`onnx_chainer` provides an `export` function that takes a Chainer model and the arguments expected by the model's `__call__` method. It runs one forward pass with the given model and arguments to construct a computational graph. Chainer, which introduced the Define-by-Run approach, constructs the computational graph for backward computation on the fly during the forward pass. `onnx-chainer` is therefore a trace-based exporter: it needs to run the model once before it can convert the structure into ONNX.\n",
    "\n",
    "`onnx_chainer.export()` also accepts other options, for example `filename`, which saves the converted model to disk."
   ]
  },
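  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what trace-based means, the toy sketch below records the operations executed during a single forward pass. It is not how `onnx-chainer` is implemented; `Traced`, `dynamic_model`, and `trace` are hypothetical names invented for this example. Only the branch that actually runs ends up in the trace:\n",
    "\n",
    "```python\n",
    "class Traced:\n",
    "    # Toy value (hypothetical): appends the name of every operation\n",
    "    # applied to it to a shared tape.\n",
    "    def __init__(self, value, tape):\n",
    "        self.value = value\n",
    "        self.tape = tape\n",
    "\n",
    "    def _record(self, op_name, new_value):\n",
    "        self.tape.append(op_name)\n",
    "        return Traced(new_value, self.tape)\n",
    "\n",
    "    def double(self):\n",
    "        return self._record('Double', self.value * 2)\n",
    "\n",
    "    def negate(self):\n",
    "        return self._record('Negate', -self.value)\n",
    "\n",
    "\n",
    "def dynamic_model(x):\n",
    "    # Data-dependent control flow: the branch taken depends on the input value.\n",
    "    if x.value > 0:\n",
    "        return x.double()\n",
    "    return x.negate()\n",
    "\n",
    "\n",
    "def trace(model, value):\n",
    "    # Run one forward pass and return the operations it executed, in order.\n",
    "    tape = []\n",
    "    model(Traced(value, tape))\n",
    "    return tape\n",
    "\n",
    "\n",
    "print(trace(dynamic_model, 3.0))   # ['Double']\n",
    "print(trace(dynamic_model, -3.0))  # ['Negate']\n",
    "```\n",
    "\n",
    "A model whose control flow depends on the input values yields a different trace for different inputs, so a single traced graph covers only one of the paths."
   ]
  },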
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Limitations\n",
    "\n",
    "[onnx-chainer](https://github.com/chainer/onnx-chainer) currently does not support exporting dynamic models, that is, models that change their behavior depending on the input data, because the ONNX format currently cannot represent such dynamic behavior.\n",
    "\n",
    "Additionally, some ONNX operators, for example `Reshape`, keep the explicit batch size taken from the shape of the dummy input data, so the exported ONNX model is runnable only with that batch size. Running the model with a different batch size at inference time may therefore fail. To change the batch size in the ONNX `Reshape` operator, you need to modify the value by hand after exporting."
   ]
  },
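  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to make that manual change is sketched below. It assumes the exported graph stores the target shape in a `shape` attribute on each `Reshape` node, as in the `Reshape[shape = [1, -1]]` node printed later in this tutorial; newer ONNX operator sets pass the shape as a second input tensor instead, which this sketch does not handle. `set_reshape_batch_size` is a hypothetical helper, not part of `onnx` or `onnx-chainer`.\n",
    "\n",
    "```python\n",
    "def set_reshape_batch_size(model, batch_size):\n",
    "    # Hypothetical helper: overwrite the leading dimension of the `shape`\n",
    "    # attribute of every Reshape node in a loaded ONNX model.\n",
    "    # Returns the number of Reshape nodes that were changed.\n",
    "    patched = 0\n",
    "    for node in model.graph.node:\n",
    "        if node.op_type != 'Reshape':\n",
    "            continue\n",
    "        for attr in node.attribute:\n",
    "            # Assumption: the target shape is stored as an integer list attribute.\n",
    "            if attr.name == 'shape' and len(attr.ints) > 0:\n",
    "                attr.ints[0] = batch_size\n",
    "                patched += 1\n",
    "    return patched\n",
    "```\n",
    "\n",
    "For example, after `model = onnx.load('output/VGG16.onnx')`, calling `set_reshape_batch_size(model, 8)` and then `onnx.save(model, 'output/VGG16_batch8.onnx')` would write a copy (the file name is only an example) whose `Reshape` nodes expect a batch size of 8. Other places that record the dummy batch size, such as the declared shape of the graph input, are left untouched by this sketch."
   ]
  },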
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Help on function export in module onnx_chainer.export:\n",
      "\n",
      "export(model, args, filename=None, export_params=True, graph_name='Graph', save_text=False)\n",
      "    Export function for chainer.Chain in ONNX format.\n",
      "    \n",
      "    This function performs a forward computation of the given\n",
      "    :class:`~chainer.Chain`, ``model``, by passing the given argments ``args``\n",
      "    directly. It means, the output :class:`~chainer.Variable` object ``y`` to\n",
      "    make the computational graph will be created by:\n",
      "    \n",
      "    y = model(*args)\n",
      "    \n",
      "    Args:\n",
      "        model (~chainer.Chain): The model object you want to export in ONNX\n",
      "            format. It should have :meth:`__call__` method because the second\n",
      "            argment ``args`` is directly given to the model by the ``[]``\n",
      "            accessor.\n",
      "        args (list or dict): The argments which are given to the model\n",
      "            directly.\n",
      "        filename (str or file-like object): The filename used for saving the\n",
      "            resulting ONNX model. If None, nothing is saved to the disk.\n",
      "        export_params (bool): If True, this function exports all the parameters\n",
      "            included in the given model at the same time. If False, the\n",
      "            exported ONNX model doesn't include any parameter values.\n",
      "        graph_name (str): A string to be used for the ``name`` field of the\n",
      "            graph in the exported ONNX model.\n",
      "        save_text (bool): If True, the text format of the output ONNX model is\n",
      "            also saved with ``.txt`` extention.\n",
      "    \n",
      "    Returns:\n",
      "        A ONNX model object.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import onnx_chainer\n",
    "help(onnx_chainer.export)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export VGG16 into ONNX\n",
    "\n",
    "Chainer provides a VGG16 implementation with pre-trained weights, so let's see how this network, which is a `chainer.Chain` object, can be exported to ONNX."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
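    "import os\n",
    "\n",
    "import numpy as np\n",
    "import chainer\n",
    "import chainer.links as L\n",
    "import onnx_chainer\n",
    "\n",
    "# VGG16 with pre-trained weights\n",
    "model = L.VGG16Layers()\n",
    "\n",
    "# Dummy input used only to trace the forward pass\n",
    "x = np.zeros((1, 3, 224, 224), dtype=np.float32)\n",
    "\n",
    "# Export in inference mode: turn the train flag off first\n",
    "chainer.config.train = False\n",
    "\n",
    "# Make sure the output directory exists, then export\n",
    "os.makedirs('output', exist_ok=True)\n",
    "onnx_chainer.export(model, x, filename='output/VGG16.onnx')"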
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**That's all.**\n",
    "\n",
    "The exported ONNX binary, `VGG16.onnx`, is now in the `output` directory."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inspecting the exported model\n",
    "\n",
    "You can use ONNX tooling to check the validity of the exported ONNX model and inspect the details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "graph Graph (\n",
      "  %4704767504[FLOAT, 1x3x224x224]\n",
      ") initializers (\n",
      "  %/conv5_1/b[FLOAT, 512]\n",
      "  %/conv5_1/W[FLOAT, 512x512x3x3]\n",
      "  %/conv5_3/b[FLOAT, 512]\n",
      "  %/conv5_3/W[FLOAT, 512x512x3x3]\n",
      "  %/conv3_2/b[FLOAT, 256]\n",
      "  %/conv3_2/W[FLOAT, 256x256x3x3]\n",
      "  %/fc6/b[FLOAT, 4096]\n",
      "  %/fc6/W[FLOAT, 4096x25088]\n",
      "  %/fc8/b[FLOAT, 1000]\n",
      "  %/fc8/W[FLOAT, 1000x4096]\n",
      "  %/conv1_2/b[FLOAT, 64]\n",
      "  %/conv1_2/W[FLOAT, 64x64x3x3]\n",
      "  %/conv2_1/b[FLOAT, 128]\n",
      "  %/conv2_1/W[FLOAT, 128x64x3x3]\n",
      "  %/conv4_1/b[FLOAT, 512]\n",
      "  %/conv4_1/W[FLOAT, 512x256x3x3]\n",
      "  %/conv2_2/b[FLOAT, 128]\n",
      "  %/conv2_2/W[FLOAT, 128x128x3x3]\n",
      "  %/conv3_3/b[FLOAT, 256]\n",
      "  %/conv3_3/W[FLOAT, 256x256x3x3]\n",
      "  %/conv1_1/b[FLOAT, 64]\n",
      "  %/conv1_1/W[FLOAT, 64x3x3x3]\n",
      "  %/conv3_1/b[FLOAT, 256]\n",
      "  %/conv3_1/W[FLOAT, 256x128x3x3]\n",
      "  %/conv4_3/b[FLOAT, 512]\n",
      "  %/conv4_3/W[FLOAT, 512x512x3x3]\n",
      "  %/conv5_2/b[FLOAT, 512]\n",
      "  %/conv5_2/W[FLOAT, 512x512x3x3]\n",
      "  %/fc7/b[FLOAT, 4096]\n",
      "  %/fc7/W[FLOAT, 4096x4096]\n",
      "  %/conv4_2/b[FLOAT, 512]\n",
      "  %/conv4_2/W[FLOAT, 512x512x3x3]\n",
      ") {\n",
      "  %4704767392 = Conv[dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%4704767504, %/conv1_1/W, %/conv1_1/b)\n",
      "  %4704767840 = Relu(%4704767392)\n",
      "  %4704767728 = Conv[dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%4704767840, %/conv1_2/W, %/conv1_2/b)\n",
      "  %4704767952 = Relu(%4704767728)\n",
      "  %4533214120 = MaxPool[kernel_shape = [2, 2], pads = [0, 0], strides = [2, 2]](%4704767952)\n",
      "  %4555989232 = Conv[dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%4533214120, %/conv2_1/W, %/conv2_1/b)\n",
      "  %4697163368 = Relu(%4555989232)\n",
      "  %4704818960 = Conv[dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%4697163368, %/conv2_2/W, %/conv2_2/b)\n",
      "  %4704819184 = Relu(%4704818960)\n",
      "  %4704819408 = MaxPool[kernel_shape = [2, 2], pads = [0, 0], strides = [2, 2]](%4704819184)\n",
      "  %4704819520 = Conv[dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%4704819408, %/conv3_1/W, %/conv3_1/b)\n",
      "  %4704819632 = Relu(%4704819520)\n",
      "  %4704819744 = Conv[dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%4704819632, %/conv3_2/W, %/conv3_2/b)\n",
      "  %4704819856 = Relu(%4704819744)\n",
      "  %4704819968 = Conv[dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%4704819856, %/conv3_3/W, %/conv3_3/b)\n",
      "  %4704820080 = Relu(%4704819968)\n",
      "  %4704820192 = MaxPool[kernel_shape = [2, 2], pads = [0, 0], strides = [2, 2]](%4704820080)\n",
      "  %4704820304 = Conv[dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%4704820192, %/conv4_1/W, %/conv4_1/b)\n",
      "  %4704820416 = Relu(%4704820304)\n",
      "  %4704820528 = Conv[dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%4704820416, %/conv4_2/W, %/conv4_2/b)\n",
      "  %4704820640 = Relu(%4704820528)\n",
      "  %4704820752 = Conv[dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%4704820640, %/conv4_3/W, %/conv4_3/b)\n",
      "  %4704820864 = Relu(%4704820752)\n",
      "  %4704820976 = MaxPool[kernel_shape = [2, 2], pads = [0, 0], strides = [2, 2]](%4704820864)\n",
      "  %4704821088 = Conv[dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%4704820976, %/conv5_1/W, %/conv5_1/b)\n",
      "  %4704821200 = Relu(%4704821088)\n",
      "  %4704895216 = Conv[dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%4704821200, %/conv5_2/W, %/conv5_2/b)\n",
      "  %4704895440 = Relu(%4704895216)\n",
      "  %4704895664 = Conv[dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%4704895440, %/conv5_3/W, %/conv5_3/b)\n",
      "  %4704895888 = Relu(%4704895664)\n",
      "  %4704896168 = MaxPool[kernel_shape = [2, 2], pads = [0, 0], strides = [2, 2]](%4704895888)\n",
      "  %4704896448 = Reshape[shape = [1, -1]](%4704896168)\n",
      "  %4704896672 = FC[axis = 1, axis_w = 1](%4704896448, %/fc6/W, %/fc6/b)\n",
      "  %4704896784 = Relu(%4704896672)\n",
      "  %4704897008 = FC[axis = 1, axis_w = 1](%4704896784, %/fc7/W, %/fc7/b)\n",
      "  %4704897232 = Relu(%4704897008)\n",
      "  %4704897456 = FC[axis = 1, axis_w = 1](%4704897232, %/fc8/W, %/fc8/b)\n",
      "  %4704897680 = Softmax[axis = 1](%4704897456)\n",
      "  return %4704897680\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "import onnx\n",
    "\n",
    "# Load the ONNX model\n",
    "model = onnx.load(\"output/VGG16.onnx\")\n",
    "\n",
    "# Check that the IR is well formed\n",
    "onnx.checker.check_model(model)\n",
    "\n",
    "# Print a human readable representation of the graph\n",
    "print(onnx.helper.printable_graph(model.graph))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can see all the parameters and layers of the VGG16 model here. The actual parameter values are stored in `model.graph.initializer`."
   ]
  }
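  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small sketch of how those stored parameters could be listed programmatically: each entry of the graph's repeated `initializer` field is a tensor with a `name` and `dims`. `initializer_shapes` is a hypothetical helper name used only for this example.\n",
    "\n",
    "```python\n",
    "def initializer_shapes(model):\n",
    "    # Hypothetical helper: map each parameter name to its shape,\n",
    "    # read from the initializer entries of a loaded ONNX model.\n",
    "    return {init.name: tuple(init.dims) for init in model.graph.initializer}\n",
    "```\n",
    "\n",
    "For the VGG16 model loaded above, `initializer_shapes(model)['/fc8/W']` should give `(1000, 4096)`, matching the printed graph. To read the values themselves as NumPy arrays, `to_array` from `onnx.numpy_helper` can be applied to an individual initializer."
   ]
  }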
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
